[llm] support tensorwise fp8/int8 training #10612
base: develop
Conversation
Thanks for your contribution!
Codecov Report
❌ Your patch check has failed because the patch coverage (16.57%) is below the target coverage (80.00%). You can increase the patch coverage or adjust the target coverage.

@@            Coverage Diff             @@
##           develop   #10612      +/-   ##
===========================================
- Coverage    46.91%   46.90%    -0.02%
===========================================
  Files          799      800        +1
  Lines       132460   132519       +59
===========================================
+ Hits         62148    62157        +9
- Misses       70312    70362       +50
@@ -478,8 +525,8 @@ def load_state_dict(
    scale_dict.update(res_scale_dict)

    if device == "cpu":
        for k in list(state_dict.keys()):
            with device_guard():

Comment: This avoids the overhead of repeatedly calling set_device.
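For illustration, a minimal self-contained sketch of the pattern this comment refers to: entering the device guard once around the whole key loop so the device is switched a single time. The `device_guard` below is a toy stand-in, not the repo's implementation.

```python
import contextlib
import paddle


@contextlib.contextmanager
def device_guard(device="cpu"):
    # Toy stand-in for the device_guard used in the diff: switch the device once,
    # restore the original device on exit.
    original = paddle.device.get_device()
    paddle.device.set_device(device)
    try:
        yield
    finally:
        paddle.device.set_device(original)


def load_state_dict_on_cpu(state_dict):
    # One guard around the whole loop instead of one per key, avoiding the
    # repeated set_device calls the comment mentions.
    with device_guard("cpu"):
        for k in list(state_dict.keys()):
            state_dict[k] = paddle.to_tensor(state_dict[k])
    return state_dict
```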
"weight_only_int4", | ||
"weight_only_int8", | ||
] | ||
elif isinstance(config.quantization_config.weight_quantize_algo, dict): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
weight_only_int8不支持不同TP分片共享同一个scale,暂不支持wint8权重灵活转化TP策略
Comment: post_quantize means the weight is first TP-split and then quantized (for wint4/wint8).
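As a sketch of what this ordering means in practice (the helper names and the `quantize_fn` returning `(quantized_weight, scale)` are illustrative assumptions, not the repo's API):

```python
import paddle


def shard_then_quantize(weight, tp_rank, tp_degree, quantize_fn):
    # post_quantize: split the full-precision weight across TP ranks first,
    # then quantize each shard with its own scale.
    shard = paddle.split(weight, tp_degree, axis=-1)[tp_rank]
    return quantize_fn(shard)


def quantize_then_shard(weight, tp_rank, tp_degree, quantize_fn):
    # The opposite ordering: quantize the full weight once with a single scale,
    # then split the quantized tensor. Since weight_only_int8 cannot share one
    # scale across shards, flexible TP re-sharding of wint8 is not supported yet.
    qweight, scale = quantize_fn(weight)
    return paddle.split(qweight, tp_degree, axis=-1)[tp_rank], scale
```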
@@ -2537,6 +2615,7 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
    # load pt weights early so that we know which dtype to init the model under
    if not is_sharded and state_dict is None:
        # 4. loading non-sharded ckpt from the state dict
        # Quantization: Loading non-sharded ckpt does not support saving with merge_tensor_parallel

Comment: Quantized loading and saving of non-safetensors weights is not considered for now.
    return block


def create_hadamard_matrix(block_size, dtype):

Comment: What is the difference between this and the earlier random_hadamard_matrix?
    if getattr(infohub, "hadamard") is None:
        setattr(infohub, "hadamard", {})

    if block_size in infohub.hadamard:

Comment: If hadamard_matrix has no default value, things will break when this branch is not taken.
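A minimal sketch of a deterministic (Sylvester-construction) Hadamard builder with a cache and an explicit fallback, so the cache-miss path always produces a matrix; unlike random_hadamard_matrix, the result depends only on block_size. The cache dict and function name are placeholders, not the repo's infohub API.

```python
import paddle

_HADAMARD_CACHE = {}  # stand-in for infohub.hadamard


def create_hadamard_matrix_sketch(block_size, dtype="bfloat16"):
    # Sylvester construction requires block_size to be a power of two.
    assert block_size > 0 and block_size & (block_size - 1) == 0
    if block_size in _HADAMARD_CACHE:
        return _HADAMARD_CACHE[block_size]
    # Cache miss: build the matrix instead of failing, then memoize it.
    h = paddle.ones([1, 1], dtype="float32")
    while h.shape[0] < block_size:
        h = paddle.concat(
            [paddle.concat([h, h], axis=1), paddle.concat([h, -h], axis=1)], axis=0
        )
    # Normalize so that H @ H.T == I, keeping the transform orthogonal.
    h = (h / float(block_size) ** 0.5).astype(dtype)
    _HADAMARD_CACHE[block_size] = h
    return h
```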
@@ -2107,16 +2109,6 @@ def get_optimizer_cls_and_kwargs(args: TrainingArguments) -> Tuple[Any, Any]:

        optimizer_cls = AdamWCustom
        optimizer_kwargs.update(adam_kwargs)
    elif args.optim == OptimizerNames.ADAMW_16BIT_MOMENT:

Comment: Why do these two AdamW implementations need to be removed?
@@ -318,8 +318,6 @@ class OptimizerNames(ExplicitEnum):
    ADAFACTOR = "adafactor"
    ADAMW_MINI = "adamw_mini"
    ADAMW_CUSTOM = "adamw_custom"
    ADAMW_16BIT_MOMENT = "adamw_16bit_moment"

Comment: Weren't these two used in any existing configurations?
@@ -868,6 +868,10 @@ class TrainingArguments:
        default="adamw",
        metadata={"help": "The optimizer to use."},
    )
    use_lowprecision_moment: bool = field(
        default=False,
        metadata={"help": "AdamW use lowbit moment as parameter."},

Comment: How many bits does "lowbit" refer to here? When is enabling it recommended, and what are the effects of turning it on? This needs a clear explanation.
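To make the question concrete, here is a rough sketch (bias correction omitted, names illustrative) of what an AdamW step with low-precision moments could look like if "lowbit" means bfloat16 moments. This illustrates the general idea behind use_lowprecision_moment, not the PR's Triton kernel; the trade-off is less precise moment accumulation in exchange for roughly half the optimizer-state memory.

```python
import paddle


def adamw_step_bf16_moments(param, grad, m_bf16, v_bf16, lr, beta1=0.9, beta2=0.999, eps=1e-8):
    # Moments are kept in bfloat16 to halve optimizer-state memory; they are
    # cast to float32 only for the update arithmetic.
    m = beta1 * m_bf16.astype("float32") + (1.0 - beta1) * grad
    v = beta2 * v_bf16.astype("float32") + (1.0 - beta2) * grad * grad
    param = param - lr * m / (paddle.sqrt(v) + eps)
    # Store the moments back in bfloat16 for the next step.
    return param, m.astype("bfloat16"), v.astype("bfloat16")
```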
@@ -996,6 +1000,10 @@ class TrainingArguments:
        default=False,
        metadata={"help": "Offload optimizer after optimizer.step()"},
    )
    tensorwise_offload_optimizer: Optional[bool] = field(

Comment: The help message does not explain this clearly; why is this option needed?
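As one reading of what the flag is for (names and structure illustrative, not the PR's implementation): optimizer state tensors are moved to pinned CPU memory after optimizer.step() and copied back tensor by tensor before the next step, trading host-device transfers for GPU memory.

```python
import paddle


def offload_optimizer_states(optimizer):
    # Move every optimizer state tensor to pinned CPU memory after step().
    offloaded = {}
    for name, value in optimizer.state_dict().items():
        if isinstance(value, paddle.Tensor) and "gpu" in str(value.place):
            offloaded[name] = value.pin_memory()
        else:
            offloaded[name] = value
    return offloaded


def reload_optimizer_states(offloaded):
    # Copy the states back to the accelerator before the next optimizer.step().
    return {
        name: value.cuda() if isinstance(value, paddle.Tensor) else value
        for name, value in offloaded.items()
    }
```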
@@ -445,7 +452,9 @@ def compute_metrics_do_generation(eval_preds):
        gen_args=gen_args,
        data_args=data_args,
    )
    trainable_parameters = [p for p in model.parameters() if not p.stop_gradient]
    trainable_parameters = [
        p for p in model.parameters() if not p.stop_gradient or ("quantization_linear" in p.name and "w_1" in p.name)
    ]

Comment: Can this hardcoded name check be avoided? If not, how do we guarantee it always takes effect? At a minimum there should be a log message.
    if self.weight_quantize_algo not in ["fp8linear", "a8w4linear"]:
        self.quant_scale.is_distributed = False
    else:
        self.quant_scale.is_distributed = True if self.is_mp else False

Comment: Does DP need to be considered here as well?
    scale = paddle.max(paddle.abs(target_x)) / qmax
    if group is not None:
        paddle.distributed.all_reduce(scale, op=paddle.distributed.ReduceOp.MAX, group=group, sync_op=True)
    if state < quantization_config.apply_online_actscale_step:

Comment: Is the online scaling here the counterpart of delayed scaling? It is not clear what this parameter affects; it would help to explain it to users.
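For reference, a small self-contained sketch of the tensorwise abs-max scale with the MAX all-reduce shown in the diff above, which is what keeps the scale identical across TP/DP shards. The warmup gating via apply_online_actscale_step is deliberately left out, since its exact semantics are what this comment asks about.

```python
import paddle
import paddle.distributed as dist


def tensorwise_act_scale(x, qmax, group=None):
    # Per-tensor abs-max scale for the activation.
    scale = paddle.max(paddle.abs(x)) / qmax
    if group is not None:
        # MAX-reduce across the parallel group so every shard quantizes with
        # the same scale, which is what makes re-sharding safe.
        dist.all_reduce(scale, op=dist.ReduceOp.MAX, group=group, sync_op=True)
    return scale
```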
PR types
New features
PR changes
APIs
Description
Newly supported features:
1. Add all_reduce_max over weight scales and activation scales, so checkpoints can be re-sharded across different TP and data-parallel strategies.
2. Support FP8/INT8 training with DP + TP + PP + Sharding stage1, storing weights and optimizer states with Unified Checkpoint.
3. Switch the Hadamard matrix multiplication to a block-diagonal Hadamard matrix.
4. Unify the FP8/INT8 training code paths.
5. Add a Triton FP8-weight AdamW optimizer (with bf16 moments and offload support).
6. Support FP8/INT8 LoRA on the backbone model.
Features planned for follow-up PRs:
1. FP8 weights are currently represented as paddle.int8 and stored as np.int8; switch to a float8 representation later (pending framework support for fp8 set_value and concat).
2. Speed up the FP8/INT8 quant-matmul-dequant path and adapt the acceleration to MoE structures.
3. Support Sharding stage2/3 for FP8/INT8 training (PP only supports stage1; low priority).